{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9e3928bf-5d05-40e3-845b-51d228a76dde",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "F:\\programData\\conda\\env\\ChatGLM2-6B\\lib\\site-packages\\numpy\\_distributor_init.py:30: UserWarning: loaded more than 1 DLL from .libs:\n",
      "F:\\programData\\conda\\env\\ChatGLM2-6B\\lib\\site-packages\\numpy\\.libs\\libopenblas.FB5AE2TYXYH2IJRDKGDGQ3XBKLKTF43H.gfortran-win_amd64.dll\n",
      "F:\\programData\\conda\\env\\ChatGLM2-6B\\lib\\site-packages\\numpy\\.libs\\libopenblas64__v0.3.23-246-g3d31191b-gcc_10_3_0.dll\n",
      "  warnings.warn(\"loaded more than 1 DLL from .libs:\"\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import nn\n",
    "\n",
    "class CenteredLayer(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self,x):\n",
    "        return x - x.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fb890fac-5aa0-461b-8a0e-4054cbeee2a2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-2., -1.,  0.,  1.,  2.])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: centring [1, 2, 3, 4, 5] should give values symmetric about 0.\n",
    "layer = CenteredLayer()\n",
    "layer(torch.FloatTensor([1,2,3,4,5]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b46ca6a0-7643-4db9-b971-1923a9612fcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the custom layer as one component of a larger model.\n",
    "net = nn.Sequential(nn.Linear(8,128),CenteredLayer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "df007b6c-7716-491e-ac83-bf306393a359",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.2573e-08, grad_fn=<MeanBackward0>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The mean of the centred output should be 0 up to floating-point error.\n",
    "Y  = net(torch.rand(4,8))\n",
    "Y.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7fac3790-0411-4686-b31c-98c81f3de7d5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[-0.3846,  0.2432, -0.1644],\n",
       "        [ 1.1808,  0.0035, -0.5359],\n",
       "        [-2.8157, -0.4648, -1.0441],\n",
       "        [ 0.1642, -1.6544, -1.4877],\n",
       "        [-1.4197,  1.7398,  1.1485]], requires_grad=True)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#5.4.2. 带参数的层\n",
    "class MyLinear(nn.Module):\n",
    "    def __init__(self,in_units,units):\n",
    "        super().__init__()\n",
    "        self.weight = nn.Parameter(torch.randn(in_units,units))\n",
    "        self.bias = nn.Parameter(torch.randn(units,))\n",
    "    def forward(self,x):\n",
    "        linear = torch.matmul(x,self.weight.data) + self.bias.data\n",
    "        return F.relu(linear)\n",
    "# Instantiate a 5-input, 3-output layer and display its weight matrix.\n",
    "linear = MyLinear(5,3)\n",
    "linear.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5045b1c9-8e08-419e-992d-9586f3c9bc80",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0000, 0.0000, 0.0000],\n",
       "        [0.0000, 0.1179, 0.0000]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Forward pass on a random batch of 2 samples with 5 features each.\n",
    "linear(torch.rand(2,5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "dc9e7a00-1644-4073-9b57-dc90b6e87e34",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.],\n",
       "        [0.]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Custom layers can be stacked inside nn.Sequential; here 64 -> 8 -> 1.\n",
    "net = nn.Sequential(MyLinear(64,8),MyLinear(8,1))\n",
    "net(torch.rand(2,64))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "dbcb36ce-d0bc-4edd-8f89-78d5e060396e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ -0.5622],\n",
       "        [-44.5627],\n",
       "        [ 10.0461],\n",
       "        [  0.3534],\n",
       "        [-18.1458]], grad_fn=<CatBackward0>)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class DemensionReduction(nn.Module):\n",
    "    def __init__(self,in_units, units):\n",
    "        super().__init__()\n",
    "        self.weight = nn.Parameter(torch.randn(in_units*in_units, units))\n",
    "    def forward(self, X):\n",
    "        Y = torch.matmul(X[0,:].reshape(-1,1),X[0,:].reshape(1,-1)).reshape(1,-1)\n",
    "        Z = torch.matmul(Y,self.weight)\n",
    "        for i in range(1,X.shape[0]):\n",
    "            Y = torch.matmul(X[i,:].reshape(-1,1),X[i,:].reshape(1,-1)).reshape(1,-1)\n",
    "            Z = torch.cat((Z,torch.matmul(Y,self.weight)),0)\n",
    "        return Z\n",
    "# Reduce a random batch of 5 samples with 20 features to one value each.\n",
    "dem = DemensionReduction(20,1)\n",
    "X = torch.randn(5,20)\n",
    "dem(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f834cca2-25a7-4d20-a49f-35d48100c869",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
